{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "eUEjpHvGyjm2",
    "outputId": "1cff5ec5-cc25-4e7f-96c8-ab15d19e7e23"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting datasets\n",
      "  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m547.8/547.8 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.15.4)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.25.2)\n",
      "Collecting pyarrow>=15.0.0 (from datasets)\n",
      "  Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.8/40.8 MB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n",
      "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n",
      "  Downloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.0.3)\n",
      "Collecting requests>=2.32.2 (from datasets)\n",
      "  Downloading requests-2.32.3-py3-none-any.whl (64 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.9/64.9 kB\u001b[0m \u001b[31m7.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.4)\n",
      "Collecting xxhash (from datasets)\n",
      "  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m17.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting multiprocess (from datasets)\n",
      "  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m15.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: fsspec[http]<=2024.5.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
      "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.5)\n",
      "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.23.4)\n",
      "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
      "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.2->datasets) (4.12.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.7)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.0.7)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.6.2)\n",
      "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
      "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.4)\n",
      "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
      "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
      "Installing collected packages: xxhash, requests, pyarrow, dill, multiprocess, datasets\n",
      "  Attempting uninstall: requests\n",
      "    Found existing installation: requests 2.31.0\n",
      "    Uninstalling requests-2.31.0:\n",
      "      Successfully uninstalled requests-2.31.0\n",
      "  Attempting uninstall: pyarrow\n",
      "    Found existing installation: pyarrow 14.0.2\n",
      "    Uninstalling pyarrow-14.0.2:\n",
      "      Successfully uninstalled pyarrow-14.0.2\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 16.1.0 which is incompatible.\n",
      "google-colab 1.0.0 requires requests==2.31.0, but you have requests 2.32.3 which is incompatible.\n",
      "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 16.1.0 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0mSuccessfully installed datasets-2.20.0 dill-0.3.8 multiprocess-0.70.16 pyarrow-16.1.0 requests-2.32.3 xxhash-3.4.1\n",
      "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.41.2)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.15.4)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.4)\n",
      "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.5.15)\n",
      "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n",
      "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n",
      "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n",
      "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.4)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers) (2023.6.0)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers) (4.12.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.6.2)\n",
      "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.0+cu121)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
      "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12.1)\n",
      "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n",
      "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
      "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2023.6.0)\n",
      "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n",
      "  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n",
      "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n",
      "  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n",
      "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n",
      "  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n",
      "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n",
      "  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n",
      "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n",
      "  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n",
      "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n",
      "  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n",
      "Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n",
      "  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n",
      "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n",
      "  Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n",
      "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n",
      "  Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n",
      "Collecting nvidia-nccl-cu12==2.20.5 (from torch)\n",
      "  Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n",
      "Collecting nvidia-nvtx-cu12==12.1.105 (from torch)\n",
      "  Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n",
      "Requirement already satisfied: triton==2.3.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.0)\n",
      "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n",
      "  Downloading nvidia_nvjitlink_cu12-12.5.82-py3-none-manylinux2014_x86_64.whl (21.3 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.3/21.3 MB\u001b[0m \u001b[31m55.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n",
      "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
      "Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n",
      "Successfully installed nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.5.82 nvidia-nvtx-cu12-12.1.105\n",
      "Collecting lancedb\n",
      "  Downloading lancedb-0.9.0-cp38-abi3-manylinux_2_28_x86_64.whl (20.9 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.9/20.9 MB\u001b[0m \u001b[31m49.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting deprecation (from lancedb)\n",
      "  Downloading deprecation-2.1.0-py2.py3-none-any.whl (11 kB)\n",
      "Collecting pylance==0.13.0 (from lancedb)\n",
      "  Downloading pylance-0.13.0-cp39-abi3-manylinux_2_28_x86_64.whl (25.5 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m25.5/25.5 MB\u001b[0m \u001b[31m47.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hCollecting ratelimiter~=1.0 (from lancedb)\n",
      "  Downloading ratelimiter-1.2.0.post0-py3-none-any.whl (6.6 kB)\n",
      "Requirement already satisfied: requests>=2.31.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (2.32.3)\n",
      "Collecting retry>=0.9.2 (from lancedb)\n",
      "  Downloading retry-0.9.2-py2.py3-none-any.whl (8.0 kB)\n",
      "Requirement already satisfied: tqdm>=4.27.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (4.66.4)\n",
      "Requirement already satisfied: pydantic>=1.10 in /usr/local/lib/python3.10/dist-packages (from lancedb) (2.8.0)\n",
      "Requirement already satisfied: attrs>=21.3.0 in /usr/local/lib/python3.10/dist-packages (from lancedb) (23.2.0)\n",
      "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from lancedb) (24.1)\n",
      "Requirement already satisfied: cachetools in /usr/local/lib/python3.10/dist-packages (from lancedb) (5.3.3)\n",
      "Collecting overrides>=0.7 (from lancedb)\n",
      "  Downloading overrides-7.7.0-py3-none-any.whl (17 kB)\n",
      "Collecting pyarrow<15.0.1,>=12 (from pylance==0.13.0->lancedb)\n",
      "  Downloading pyarrow-15.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (38.3 MB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m38.3/38.3 MB\u001b[0m \u001b[31m12.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from pylance==0.13.0->lancedb) (1.25.2)\n",
      "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (0.7.0)\n",
      "Requirement already satisfied: pydantic-core==2.20.0 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (2.20.0)\n",
      "Requirement already satisfied: typing-extensions>=4.6.1 in /usr/local/lib/python3.10/dist-packages (from pydantic>=1.10->lancedb) (4.12.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (3.3.2)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (3.7)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (2.0.7)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.31.0->lancedb) (2024.6.2)\n",
      "Requirement already satisfied: decorator>=3.4.2 in /usr/local/lib/python3.10/dist-packages (from retry>=0.9.2->lancedb) (4.4.2)\n",
      "Collecting py<2.0.0,>=1.4.26 (from retry>=0.9.2->lancedb)\n",
      "  Downloading py-1.11.0-py2.py3-none-any.whl (98 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.7/98.7 kB\u001b[0m \u001b[31m8.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hInstalling collected packages: ratelimiter, pyarrow, py, overrides, deprecation, retry, pylance, lancedb\n",
      "  Attempting uninstall: pyarrow\n",
      "    Found existing installation: pyarrow 16.1.0\n",
      "    Uninstalling pyarrow-16.1.0:\n",
      "      Successfully uninstalled pyarrow-16.1.0\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 15.0.0 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0mSuccessfully installed deprecation-2.1.0 lancedb-0.9.0 overrides-7.7.0 py-1.11.0 pyarrow-15.0.0 pylance-0.13.0 ratelimiter-1.2.0.post0 retry-0.9.2\n"
     ]
    }
   ],
   "source": [
    "!pip install datasets\n",
    "!pip install transformers\n",
    "!pip install torch\n",
    "!pip install lancedb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 361,
     "referenced_widgets": [
      "1f02aec59eb04b2e873731e795406abe",
      "995f76f7635e45ed8948c1618bedbd9b",
      "e0ce157c3b274a5abd952d5515c37c70",
      "7e3dfd77b9224f43b816bc78943fcc4d",
      "f83187d443c94cec95658106736cdc22",
      "c7c890a6cbea4dd7a7cd403159626a71",
      "9568f6f5711046b1942521edb28c8190",
      "c78c1963e3244dfd9bcaf5852ee2769b",
      "add3c8bffca840a68dade1f89e65916c",
      "58c071b5300c43d0a41716760e8079fc",
      "5cc22c407d8246e787cef4f9f875b64b",
      "d22bd927dda740dcb74183fe4c54b7d8",
      "2ef6a34add5c486d861afa8dea774825",
      "b483fb599bdb474e87ceaf522c289ae1",
      "059fc21694e94ccb8c2dfb4255bf527e",
      "187f91bd1b1e41d48ace6365083c5087",
      "5683f6a74ba240d68b8669fe8ea57e02",
      "dcd4222a41734ccca8d09d8f8cf52da7",
      "d493e8b6833b4c67a4c9501f7e674152",
      "4c99a46609d0422981ab6ae626622b4d",
      "e08fea8e4b4e4196a9cd46bbcfedddc7",
      "72c6f68df909430d892efb5f2456b1be",
      "62386fe9060e470290d0130b7c5a7238",
      "11ff18c10944404a80953faf74327ce9",
      "0846aa03430a4d02bb1b0a0953a6caab",
      "5fc08303262748fe9f27462df799cb82",
      "5e65707ec5974579867af83852bc8316",
      "ce090ed8f3534a0c8d400005add62eaa",
      "ef54ceb328b4476297527227beee97ac",
      "8e836cd77269434e8d357e7cab7197b2",
      "4a58cdd528fb4a04b9a6c33dfd52728f",
      "9c6e9c18153e48fea6862309ff769ab9",
      "17a324ae28df41e78cbeead4da5792a5",
      "a2ed30a97cc142288da82b9566ca9d45",
      "a277ecdcc30a4d8994f7a804f76be812",
      "44fc658045eb40c78c5e59a1c3346fb8",
      "f070fa459fce402785b421bdf2a78b1c",
      "ecbea823cda848b795622f17624fc212",
      "456b72feeb754ef3918fabfcbc48da90",
      "322958a4c68d41b791c85ae588531731",
      "e7251577425b4ff7ba7c504cf8051a4c",
      "7964da0451ef4847b4ff425ad63453b2",
      "a1e5fab3cbba4554a66f4eb390aa8cfa",
      "40a537f833084f2cb28997d4cf095abe",
      "32a0f9e09e454f21a1d3adbb1833968b",
      "9dc5f3805e1a4e498e78e0b4e996c36c",
      "4dfe744e6bf247abb3444648ac50d1c6",
      "46c4dcafbb0f4495b29b594704524ad0",
      "45e65d18e86d4a9bb1cfad86dd9fa2c3",
      "db535d1a2dc14a8ea2a8502ee3246146",
      "7cea2d8716e642508d17ea0005dbb99e",
      "6088f64044674783b872a89b0ac2d6a0",
      "a0950ac0144f4964ab85633947a657be",
      "fb1d9819f508498c9374e52da815af27",
      "0319c3b89fb549d7b400cc9e3c7099e2"
     ]
    },
    "id": "Bprp6zdzm3P7",
    "outputId": "58cd1daf-0848-4a31-c9e7-64ee3edf4591"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
      "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
      "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
      "You will be able to reuse this secret in all of your notebooks.\n",
      "Please note that authentication is recommended but still optional to access public models or datasets.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f02aec59eb04b2e873731e795406abe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading readme:   0%|          | 0.00/9.98k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d22bd927dda740dcb74183fe4c54b7d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/119M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62386fe9060e470290d0130b7c5a7238",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/23.8M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a2ed30a97cc142288da82b9566ca9d45",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating train split:   0%|          | 0/50000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32a0f9e09e454f21a1d3adbb1833968b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['img', 'fine_label', 'coarse_label'],\n",
       "    num_rows: 10000\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "imagedata = load_dataset(\"uoft-cs/cifar100\", split=\"test\")\n",
    "\n",
    "imagedata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "PIOWR50hm3P-",
    "outputId": "d5c7a4e1-7ffc-4120-f561-31e4dfec02b7"
   },
   "outputs": [],
   "source": [
    "set(imagedata[\"fine_label\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6-XhuZFgm3P-",
    "outputId": "4edf3803-9b64-4c19-938f-6a21308e1556"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['apple',\n",
       " 'aquarium_fish',\n",
       " 'baby',\n",
       " 'bear',\n",
       " 'beaver',\n",
       " 'bed',\n",
       " 'bee',\n",
       " 'beetle',\n",
       " 'bicycle',\n",
       " 'bottle',\n",
       " 'bowl',\n",
       " 'boy',\n",
       " 'bridge',\n",
       " 'bus',\n",
       " 'butterfly',\n",
       " 'camel',\n",
       " 'can',\n",
       " 'castle',\n",
       " 'caterpillar',\n",
       " 'cattle',\n",
       " 'chair',\n",
       " 'chimpanzee',\n",
       " 'clock',\n",
       " 'cloud',\n",
       " 'cockroach',\n",
       " 'couch',\n",
       " 'cra',\n",
       " 'crocodile',\n",
       " 'cup',\n",
       " 'dinosaur',\n",
       " 'dolphin',\n",
       " 'elephant',\n",
       " 'flatfish',\n",
       " 'forest',\n",
       " 'fox',\n",
       " 'girl',\n",
       " 'hamster',\n",
       " 'house',\n",
       " 'kangaroo',\n",
       " 'keyboard',\n",
       " 'lamp',\n",
       " 'lawn_mower',\n",
       " 'leopard',\n",
       " 'lion',\n",
       " 'lizard',\n",
       " 'lobster',\n",
       " 'man',\n",
       " 'maple_tree',\n",
       " 'motorcycle',\n",
       " 'mountain',\n",
       " 'mouse',\n",
       " 'mushroom',\n",
       " 'oak_tree',\n",
       " 'orange',\n",
       " 'orchid',\n",
       " 'otter',\n",
       " 'palm_tree',\n",
       " 'pear',\n",
       " 'pickup_truck',\n",
       " 'pine_tree',\n",
       " 'plain',\n",
       " 'plate',\n",
       " 'poppy',\n",
       " 'porcupine',\n",
       " 'possum',\n",
       " 'rabbit',\n",
       " 'raccoon',\n",
       " 'ray',\n",
       " 'road',\n",
       " 'rocket',\n",
       " 'rose',\n",
       " 'sea',\n",
       " 'seal',\n",
       " 'shark',\n",
       " 'shrew',\n",
       " 'skunk',\n",
       " 'skyscraper',\n",
       " 'snail',\n",
       " 'snake',\n",
       " 'spider',\n",
       " 'squirrel',\n",
       " 'streetcar',\n",
       " 'sunflower',\n",
       " 'sweet_pepper',\n",
       " 'table',\n",
       " 'tank',\n",
       " 'telephone',\n",
       " 'television',\n",
       " 'tiger',\n",
       " 'tractor',\n",
       " 'train',\n",
       " 'trout',\n",
       " 'tulip',\n",
       " 'turtle',\n",
       " 'wardrobe',\n",
       " 'whale',\n",
       " 'willow_tree',\n",
       " 'wolf',\n",
       " 'woman',\n",
       " 'worm']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# labels names\n",
    "labels = imagedata.info.features[\"fine_label\"].names\n",
    "print(len(labels))\n",
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fjYFjp6Gm3P-",
    "outputId": "4311a0ed-7d5d-40d4-dae6-6dc62c6f245d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['This is an image for apple',\n",
       " 'This is an image for aquarium_fish',\n",
       " 'This is an image for baby',\n",
       " 'This is an image for bear',\n",
       " 'This is an image for beaver',\n",
       " 'This is an image for bed',\n",
       " 'This is an image for bee',\n",
       " 'This is an image for beetle',\n",
       " 'This is an image for bicycle',\n",
       " 'This is an image for bottle',\n",
       " 'This is an image for bowl',\n",
       " 'This is an image for boy',\n",
       " 'This is an image for bridge',\n",
       " 'This is an image for bus',\n",
       " 'This is an image for butterfly',\n",
       " 'This is an image for camel',\n",
       " 'This is an image for can',\n",
       " 'This is an image for castle',\n",
       " 'This is an image for caterpillar',\n",
       " 'This is an image for cattle',\n",
       " 'This is an image for chair',\n",
       " 'This is an image for chimpanzee',\n",
       " 'This is an image for clock',\n",
       " 'This is an image for cloud',\n",
       " 'This is an image for cockroach',\n",
       " 'This is an image for couch',\n",
       " 'This is an image for cra',\n",
       " 'This is an image for crocodile',\n",
       " 'This is an image for cup',\n",
       " 'This is an image for dinosaur',\n",
       " 'This is an image for dolphin',\n",
       " 'This is an image for elephant',\n",
       " 'This is an image for flatfish',\n",
       " 'This is an image for forest',\n",
       " 'This is an image for fox',\n",
       " 'This is an image for girl',\n",
       " 'This is an image for hamster',\n",
       " 'This is an image for house',\n",
       " 'This is an image for kangaroo',\n",
       " 'This is an image for keyboard',\n",
       " 'This is an image for lamp',\n",
       " 'This is an image for lawn_mower',\n",
       " 'This is an image for leopard',\n",
       " 'This is an image for lion',\n",
       " 'This is an image for lizard',\n",
       " 'This is an image for lobster',\n",
       " 'This is an image for man',\n",
       " 'This is an image for maple_tree',\n",
       " 'This is an image for motorcycle',\n",
       " 'This is an image for mountain',\n",
       " 'This is an image for mouse',\n",
       " 'This is an image for mushroom',\n",
       " 'This is an image for oak_tree',\n",
       " 'This is an image for orange',\n",
       " 'This is an image for orchid',\n",
       " 'This is an image for otter',\n",
       " 'This is an image for palm_tree',\n",
       " 'This is an image for pear',\n",
       " 'This is an image for pickup_truck',\n",
       " 'This is an image for pine_tree',\n",
       " 'This is an image for plain',\n",
       " 'This is an image for plate',\n",
       " 'This is an image for poppy',\n",
       " 'This is an image for porcupine',\n",
       " 'This is an image for possum',\n",
       " 'This is an image for rabbit',\n",
       " 'This is an image for raccoon',\n",
       " 'This is an image for ray',\n",
       " 'This is an image for road',\n",
       " 'This is an image for rocket',\n",
       " 'This is an image for rose',\n",
       " 'This is an image for sea',\n",
       " 'This is an image for seal',\n",
       " 'This is an image for shark',\n",
       " 'This is an image for shrew',\n",
       " 'This is an image for skunk',\n",
       " 'This is an image for skyscraper',\n",
       " 'This is an image for snail',\n",
       " 'This is an image for snake',\n",
       " 'This is an image for spider',\n",
       " 'This is an image for squirrel',\n",
       " 'This is an image for streetcar',\n",
       " 'This is an image for sunflower',\n",
       " 'This is an image for sweet_pepper',\n",
       " 'This is an image for table',\n",
       " 'This is an image for tank',\n",
       " 'This is an image for telephone',\n",
       " 'This is an image for television',\n",
       " 'This is an image for tiger',\n",
       " 'This is an image for tractor',\n",
       " 'This is an image for train',\n",
       " 'This is an image for trout',\n",
       " 'This is an image for tulip',\n",
       " 'This is an image for turtle',\n",
       " 'This is an image for wardrobe',\n",
       " 'This is an image for whale',\n",
       " 'This is an image for willow_tree',\n",
       " 'This is an image for wolf',\n",
       " 'This is an image for woman',\n",
       " 'This is an image for worm']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# generate sentences\n",
    "clip_labels = [f\"This is an image for {label}\" for label in labels]\n",
    "clip_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 273,
     "referenced_widgets": [
      "ab5b15d685374bae93ae855987640062",
      "983aa7c8246d45bc861de462a39de0f1",
      "0e6d4659cf2745a3a77288fa7ab69ca3",
      "3371577d4e214cb7bbe9dcd9875bd0c6",
      "db637e06ab2645bfb871d00e9dad9472",
      "bccff3bf099341a993ee9b858a4b6d3e",
      "fb3e3274b9424110bfc8ba119d4f546a",
      "9a909dfee1124ce085bb03b57a8fdbc6",
      "26a966364dff4c34a26d96ecb421c511",
      "e3191adc72dd4d969027bf033394f823",
      "49015e83281b486c9d38dac59f8ddd1a",
      "f2ce9942dde7407388f05156765b793a",
      "32c0a49dc9da4ea4a3db5bc58f5dbc06",
      "d63228bed8ea48f495b3c6aa42ca67b6",
      "0cfa21cb2b6849da87a9b0e5dca3a704",
      "d29fe2e502c54d5f932beb236316d509",
      "ee2061247bda43479e563152e755250d",
      "58c41766d99d4c86a22832a5f2a08f46",
      "0ae7c9c963334dd4afab2e906e1da869",
      "79c4e32ed15a4a039b730461b0d4417e",
      "ba310787d31243edbb34006c44756a57",
      "15ff2c607d3f44c5a45daccb264435d6",
      "2f7f5744433c4513b0300fdf0fd0a7e0",
      "b2a4c91aa4ec417a9f1ad84309f12305",
      "db1258b671e84c95bdd49afdae9ea592",
      "c67e7dd4931248dd9fcb7e5b7bb7f550",
      "b2a42773c92c491980df643d0ba86d36",
      "382f0bbd49574ff3bc9e8d1c80f3619e",
      "d134afd006874440a22e9387fe8097f2",
      "12daa541e6f64b3b9074421d23f97629",
      "e524a1465dbc4893b8c2639fb721f29b",
      "b8ed6c12d951462988011b883e9aa0c6",
      "faff56fd2e73449c9c1320fd4e30d4a8",
      "ff063f93925245e083713f1ab0e76c30",
      "eb4adcce8b5c4143afdd0ab424ea62c3",
      "7dd149bd40d84442bf35b7a158aa81df",
      "55abf1997ad54432957c4a8e33201588",
      "257d45c34f7348bda929c5d891285bd0",
      "b11b3474084d4ef99063ab7fa9417bab",
      "83acaced195a4336ba6f9c5b1ec9eaa4",
      "fc12c42748714969b906c6119d31502c",
      "d079219a7b4b46a98de5d34d44471f65",
      "bf2de599b3e84478bbd8a9ca1e24be58",
      "ab1e2dabe572430e8e12ec485d44e1c9",
      "89e36ea2dcf44340b61bdb7f686a4bab",
      "be0bf5a5a568450c87df0d94d30ae5e9",
      "90b5df728aee45f694da1e73c4ca2c06",
      "d6f4c6c319074e879ccd03b9a13f6f38",
      "0b1bc04b75ec45efb398ea15a623340b",
      "9552c8c0f7d1465098645457110d166e",
      "27274d40ee664855b2d4d2b8b8239644",
      "5378f73084fb4ca48eec86bc821bfe64",
      "3dfa28af9dde41f0ab27439af3f604e0",
      "0f7bf73fb19a4b4cba9716a39cb99042",
      "ef446390ab414a1db68a9cc44dd4ad45",
      "6d67f428248442bab9509472748a6caa",
      "948287163e3846d3a541a44e8f86eca0",
      "70429e047dd24c7eaae2488bdf5522df",
      "eeed8c2cd026489faf124f1714dc6f6e",
      "b53fb1b8922e4eeaa77a9f5f9a57232d",
      "6e3c4b91f37c4818864f7469612602fb",
      "810c5163652547829a5343fa80efa75e",
      "90c432b2f65b497a8181e3fed56bddb1",
      "d1706515c5a94ed48b1e14c12239b635",
      "8aa5ad1ef61247baa44c5908f4ff7ff6",
      "604c7add75764547adba77993cfaefcf",
      "85da268e7ef44fd1b47058ff931f96b7",
      "6910bc1d16d443a3a839a0d991152c64",
      "b8be9a6284214de48ae61d2a77fbab42",
      "74ec1a937b7a4946abd7cb327eb814bb",
      "48944041e08f4d778f0a6ea01d65c673",
      "c3821ccbd60f48299d184854af9145d8",
      "bfb083a07b5f4a5786d8e02ea58e4731",
      "bd3bb9c81a1a4dbd972a7552642e80df",
      "95b8707642574bca8f8ae2009f6e7128",
      "5b530a1622eb4ed48c00ada6f8ffe577",
      "c061e99e48844e65b9e47f8776578f01",
      "d34ba016218d4df89a52a3e3a58ac961",
      "649e69a1df7d422599fc709a222f8cb1",
      "c784b3b207cd4c70bf82e8eb6f0d9b6b",
      "4fa9de3d581445c89adbaa2af6225d2e",
      "b56e468be14a4c1890d71f1261cfcbf7",
      "58dc8b84abf74649949442f09b1689c6",
      "31748e59d7cb427fb4e63e04277d7558",
      "663cced6fa204e4aa6e54906b2aeffaa",
      "fa4dbc797ef2425d8641744201389e54",
      "873e4e5a2af74ff8b1d62969b433eaa5",
      "f3a19e06faa1494ca97fb75b3de56e50"
     ]
    },
    "id": "Ld2HC8a_m3P-",
    "outputId": "f84f892a-d35e-4b87-fda5-cea6f54dfa2e"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab5b15d685374bae93ae855987640062",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f2ce9942dde7407388f05156765b793a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f7f5744433c4513b0300fdf0fd0a7e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.json:   0%|          | 0.00/961k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff063f93925245e083713f1ab0e76c30",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89e36ea2dcf44340b61bdb7f686a4bab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d67f428248442bab9509472748a6caa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "85da268e7ef44fd1b47058ff931f96b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/4.52k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d34ba016218d4df89a52a3e3a58ac961",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# initialization\n",
    "from transformers import CLIPProcessor, CLIPModel\n",
    "\n",
    "model_id = \"openai/clip-vit-large-patch14\"\n",
    "\n",
    "processor = CLIPProcessor.from_pretrained(model_id)\n",
    "model = CLIPModel.from_pretrained(model_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 36
    },
    "id": "i7Zfo6Z6m3P_",
    "outputId": "ad370ae7-8eb5-4097-8291-a4d85ebc436e"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.google.colaboratory.intrinsic+json": {
       "type": "string"
      },
      "text/plain": [
       "'cpu'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# if you have CUDA set it to the active device like this\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# move the model to the device\n",
    "model.to(device)\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0NX4n3Jym3QA",
    "outputId": "04fbbd14-f1b8-4523-ec59-cf3e6c1ac037"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Token ID : tensor([49406,   589,   533,   550,  2867,   556,  3055, 49407, 49407, 49407]), Text : <|startoftext|>this is an image for apple <|endoftext|><|endoftext|><|endoftext|>\n",
      "Token ID : tensor([49406,   589,   533,   550,  2867,   556, 16814,   318,  2759, 49407]), Text : <|startoftext|>this is an image for aquarium _ fish <|endoftext|>\n",
      "Token ID : tensor([49406,   589,   533,   550,  2867,   556,  1794, 49407, 49407, 49407]), Text : <|startoftext|>this is an image for baby <|endoftext|><|endoftext|><|endoftext|>\n",
      "Token ID : tensor([49406,   589,   533,   550,  2867,   556,  4298, 49407, 49407, 49407]), Text : <|startoftext|>this is an image for bear <|endoftext|><|endoftext|><|endoftext|>\n",
      "Token ID : tensor([49406,   589,   533,   550,  2867,   556, 22874, 49407, 49407, 49407]), Text : <|startoftext|>this is an image for beaver <|endoftext|><|endoftext|><|endoftext|>\n"
     ]
    }
   ],
   "source": [
    "# create label tokens\n",
    "label_tokens = processor(text=clip_labels, padding=True, return_tensors=\"pt\").to(device)\n",
    "\n",
    "# Print the label tokens with the corresponding text\n",
    "for i in range(5):\n",
    "    token_ids = label_tokens[\"input_ids\"][i]\n",
    "    print(\n",
    "        f\"Token ID : {token_ids}, Text : {processor.decode(token_ids, skip_special_tokens=False)}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XOTmBUYWm3QB"
   },
   "source": [
    "## Creating the embeddings!\n",
    "\n",
    "When you're working with zero-shot image classification using CLIP (Contrastive Language-Image Pre-Training), you're essentially leveraging the ability of the CLIP model to understand both images and text by mapping them into a shared embedding space. This shared space allows the model to compute the similarity between text and image embeddings, making it ideal for zero-shot tasks.\n",
    "\n",
    "Given that you're using CLIP for zero-shot classification, it's important to use the embeddings generated by the CLIP model for both images and text. The reason for this is that CLIP has been specifically trained to align these two modalities (text and images) in the same embedding space. Using different embeddings would break this alignment and undermine the model's ability to perform zero-shot classification effectively.\n",
    "\n",
    "However, if you want to explore other embeddings, you could consider using different pre-trained models for other tasks or experiments, but for the zero-shot classification task with CLIP, sticking with the CLIP-generated embeddings is essential."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NJH-9a0Om3QB",
    "outputId": "8cc4355f-905e-4878-f85a-2233f96ae6f0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': tensor([[49406,   589,   533,   550,  2867,   556,  3055, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 16814,   318,  2759, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  1794, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  4298, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 22874, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  2722, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  5028, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 16534, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 11652, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  5392, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3814, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  1876, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  2465, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  2840, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  9738, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 21914, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,   753, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3540, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 27111, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 13644, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  4269, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10543,  1072, 14080, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  6716, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3887, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,   622,   916, 31073, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 12724, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 18362, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 24757, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  1937, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 15095, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 16464, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10299, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  8986,  2759, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  4167, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3240, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  1611, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 33313, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  1212, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 25513, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 13017, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10725, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 11024,   318, 30895, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 15931, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  5567, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 17221, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 13793, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,   786, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10570,   318,  2677, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10297, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3965, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  9301, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 13011, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  7221,   318,  2677, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  4287, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 18678, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 22456, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  8612,   318,  2677, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 18820, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 15382,   318,  4629, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  7374,   318,  2677, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10709, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  5135, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 15447, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,   817,  5059,   715, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 38575, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10274, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 29516, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3077, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  1759, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  8383, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3568, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  2102, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10159, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  7980, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 12101,   342, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 42194, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3075, 11187,  1284, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 23132, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  8798, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  7622, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 14004, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 34268, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 21559, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  2418,   318,  8253, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  2175, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  6172, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 17243, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  8608, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  6531, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 14607, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  3231, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 14853, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 28389, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10912, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 15020, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 11650, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 15665,   318,  2677, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  5916, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556,  2308, 49407, 49407, 49407],\n",
      "        [49406,   589,   533,   550,  2867,   556, 10945, 49407, 49407, 49407]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])}\n"
     ]
    }
   ],
   "source": [
    "print(label_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "HdS-vWVUm3QB",
    "outputId": "34f1c6cc-4739-4b8d-d711-303ecbc2ec36"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(100, 768)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# encode tokens to sentence embeddings from CLIP\n",
    "\n",
    "with torch.no_grad():\n",
    "    label_emb = model.get_text_features(\n",
    "        **label_tokens\n",
    "    )  # passing the label text as in \"a photo of a cat\" to get it's relevant embedding from clip model\n",
    "\n",
    "# Move embeddings to CPU and convert to numpy array\n",
    "label_emb = label_emb.detach().cpu().numpy()\n",
    "label_emb.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tLbzGPwym3QC"
   },
   "source": [
    "We got the 768 dimensional vector for each of our 10 text classes sentences.\n",
    "\n",
    "Now there needs to be the normalizing the embeddings to make sure we get the better results when we are doing the dot product, and we can do this by dividing each of the vectors by the square root of the sum of the squares of each of the vectors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "hAZs9Zh2m3QC",
    "outputId": "04e7bc8e-dd1b-4937-b570-fb336efb180b"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-15.682584, 13.341022)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label_emb.min(), label_emb.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ggHbvVpVm3QC",
    "outputId": "5843b2fb-b5a4-47d8-ed7a-efdb5bdaf409"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-0.3624336, 0.39170283)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "# normalization\n",
    "label_emb = label_emb / np.linalg.norm(label_emb, axis=0)\n",
    "label_emb.min(), label_emb.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 49
    },
    "id": "871ed-tBm3QC",
    "outputId": "73bba5c9-ef85-416e-8454-ffeb6f0151b0"
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAJe0lEQVR4nC3NWW9cVwEA4LPfde7Mnc2OHddO4rRNt9BQt0lRQ6l4AInfxxM8AD8ApD4gIQS0VKUVKEqT2LET2jh27JmxZ73b2c/hhe8PfLDRBlqHPPDQaWtOT84hRP/66uv79z/pr6/9+re/+cc/v7i5sy0EXxXFjWs3bm7dYAw++O7r58eHyCOtkafQUgG88h6Ml+OA4YiwyaKUXEc4IpRgDyAEAEBUNfV4MqY08BB4740xvK6xB1bqdpR6ZfpZZ5j1rdZvbL3Ta/WRt5Kry+XME81FwWVN86EGzaqaa6jjVqqXliznS0ppFIXa6Ef7+w8ePlTKVKvy2vYOwHg1XxIAW1E86Pc/uHNn9/rNyWl9Npoo08k67cX0lQd8c7OX5ZkzTVMsFNLfnx8KrbqJdw3GDSSr5WJtfV1qdTmffvuffx8fnxhljdRccAg8dC4OorXhcG9vb2dnG0L6+MnjkwulrGMx+vabg/Hoh/6wOxiuFbPLTkR7g0HevTYw6rQ4NtYRGhMp5GK1LOv6ycHjoiw8AMoYjJBSGngQUvbpJ/c//uRev9dL04Rz5aEnLESeAWSWdX168apUS2X94dE+sCIAcWeYbrzRrsraCmRrR7rd/HwyObuYXEwvZ4tZ3XDokdZuMhnv7t746f37b731tgaGBkGUJNqCQS8rZpYLyDVvylkcwX4v7eVRkpFGAOggV3KxWGGMi6qBxhMPAJfi4OBgWk8LVSCEoHEhY61WkuetOHxDKWe8JwhLLqtVgwGMQyaEcIZ7rUIaf/DRvflsipFJQ8pI7Akw3hPMrK/SmJEXJyc/HB+/eHnMbQ2c99oMemt77+/dfu+2MboRAviAEFRXjXOurqS2BhKonRJGsIDVBkjrolYriOK6qqS2SojhlbXBsJ9nQ6sBeXL09Nnz57PFnDEUEbb7xu69vXtXN7Yc8NYDFkQYhgh5KaUH3jnnIPDYeWIxw628U/vqu/2Drc0NGqUMEOQjX/ubW9e3tweGwFejMzKZXk7nM6VUzNK7dz66u3cvz3LrgbIaEUIQpiiEwEopgyCoilJphRlIWkxaDDEklEZRgkmQZj0SWasQQ+jibOxtGfUyjxxZLBay4Xmr/atf/PLO7fedBQgzShmyynmNEIYIeWs88M46qRRGMKAojslyaTtZGxITxikjQZbkumjWtzdNMZuND11Q2GaSdTpENnxnc+vT+/e3rm5TFjoHIQ2Utc57ABFCxDvorHXWKacYpSGlwGlK4lOleFkC4FQjXOwSmvIAQRxAGkhncIdJI88mpyRg4Vu33r2+fcNYByEBEChtnHPee4yhdQZ4B7z23tcNNxZjhKFXIaPF4mJ08kJ6dWX7WkiIsB4gwBU3ShS8HkDGRa2EJpXg2kJtgPfQamedAwB45521iCLgHYRe60YJvZqXUgJKAgSM5vpyfIq8dkohAJQUznvrNS+mq4sziHFZympWJDQmRVU6Y4H1WisVKP//wAHgtPIAOkZJK441L8plAQGlmHgNi2U9Hc25EEGaNbWfTyuCMRLO6QIKHmCSZ50IOKcckbJRkhupjFVCNBAAjAkADmOfJFEcB4xRQkLt4U3CgEePHh+VtWi03H3z9fZaRONEaJKGIbIu4Y32iMeIqzIg2DBS8ooIUTujA0wBdBB6SkmSJFkrYQwjDDACDgLlUXutP7hCiuVqbRUPXCuO2m/f2V5OLy8Xo+6VmGKPIa4aPb+s14dbr8bH3zz9M8KpdY5A5LxzquadtTzqpiELMCEIQQS8t8Y4BxCEiASEIAfyNHlt2EqSdru1bowZvRyPVxVbf+FQ4xwZuO5GvrUzfPv61d7D53+5WNQkoAQAJ5pmejHtrOUBYwB44JxzwAMAAPDAO2ctF1Io1XBd15MXR5TE82jCuZ7PViY/4NnfBVoYhRN9TRzf3X9UczyzUi5XC8wioq0OEtYddpQUrpFOaKWsEkoLJWW9VPW8LHGp7GLlK6lroZqFlkZJKIVZiLJ169mt9RoENXBMX0yffvVVfdZXkajimhcKEksstEFKu+u9o6fPnn23HxkIlPHGO2kgsvujl59/+eVn7324GUZMWuiIdVAIW9dcSlULsXhqZlJoZJ3lenlRnEBTQcm1hPZKdyNKAgK8ds5gQvefHJ08+/7+j+9QCCijQUCjKJ0eLi9nK0fwjdu3YoKFMI+OfnhZLkez8Wq5dBaLCwPPc+D7BmhhZeTCpp65qP7Jzz7Kh11MIMmSEEEPoA/CaLB+5fV33yHYaKfm01nYzrizex/d3djZAXmsmeONq1vxBYJNK9bYHx09zPv53R/9/NrVG0Ecff/qpLicISJ9WrfWezQNtVPk4w/3+r0Ol0WWtwginbV+kGCu+Plivt7Nw1aWYtvu5TfffzfJw7Jonp58kXX6V7dfWxWL0eJVmGdxrxO0GYYmYv68nLa7mRDodDJOVWK8JsP+QIpqPDlG2CICL5bzwEdSqkWlkrKGhGysDeIsndcVSJm2uKnl6GwsRNrttj6++5lGwDq7LAuguOHTQZcqq8uisWZCwRqLKPEaNrIajU8gjqM0eHTwpL0+iMKW1KisFcG0224770YXl9I5VdqXL09fnLxg0bXru1fbnaxSKwXqazfeJpY8+HZc8yWgMYY+VnArbEdxRAIzBNCrylsHrTO//8PvPAa3b7//2uZrZbWYL5b9fK3hDb+UBFFnvAKyqFdCyyhM5vPpslhYyPf3H3hlL2bnZbNaVa8ooTtb3SHJYpoS32TGaM7F5XwyX02PDv67XM0PH+9vbm1CyIaDq2+++d7ldP6nz/9IArSxtTGaTGmoR6MXf/vrsmjq09GpkAUmrp2F7TQcDHoQizTFVbM6PHyadwcEAuKcBR4HNDVy/t47H9Z1VVTz+WzFmyUGWV0KRtzzo2eT+TkLEYQQIzyj9BWmkODzyZm18srGwEO8ahQqQRDQeb0ykNmE1uWEIEQCBikJojjN8/7ru7eaplkVs9nicjpbLObFs8P/xiHijfQWSK6ghwhC4QWEEGJMURjSyCosGtfpJVHcSpI4SoIwpWEYIYwIgtABiBHEGDIaxWHYaXWGg8G23uZclmWtlZxNRzevvylMo4zQ3BlltdEOAMaCKI6yrNXptNvdOB/GrVactpIgohYaZRRCgHhoATCY0jhiTVMD4IEHBCAcxFEQd9qZMXpt0L1+fZcryYVoVjVvhHYGYByEURIlcZwkScJCjJkF0BGCMIEAWYN0EOD/Aau6KgiqSnuZAAAAAElFTkSuQmCC",
      "text/plain": [
       "<PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "index = random.randint(0, len(imagedata) - 1)\n",
    "selected_image = imagedata[index][\"img\"]\n",
    "selected_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "CPLPVY20m3QD",
    "outputId": "52018039-07a0-4121-9c2d-30dfa35920e2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 3, 224, 224])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "image = processor(text=None, images=imagedata[index][\"img\"], return_tensors=\"pt\")[\n",
    "    \"pixel_values\"\n",
    "].to(device)\n",
    "image.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ETn6TbD_m3QD",
    "outputId": "018f2ce7-98ac-4d13-c297-af92290bb4f4"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[ 1.7844,  1.7844,  1.7844,  ..., -1.2083, -1.2083, -1.2083],\n",
       "          [ 1.7990,  1.7990,  1.7990,  ..., -1.2229, -1.2229, -1.2083],\n",
       "          [ 1.7990,  1.7990,  1.7990,  ..., -1.2229, -1.2229, -1.2229],\n",
       "          ...,\n",
       "          [ 0.6603,  0.6603,  0.6603,  ...,  0.4997,  0.5143,  0.5143],\n",
       "          [ 0.6603,  0.6603,  0.6603,  ...,  0.5143,  0.5289,  0.5289],\n",
       "          [ 0.6603,  0.6603,  0.6603,  ...,  0.5289,  0.5435,  0.5435]],\n",
       "\n",
       "         [[ 2.0149,  2.0149,  2.0149,  ..., -0.9267, -0.9267, -0.9267],\n",
       "          [ 2.0149,  2.0149,  2.0149,  ..., -0.9267, -0.9267, -0.9267],\n",
       "          [ 2.0149,  2.0149,  2.0149,  ..., -0.9417, -0.9417, -0.9417],\n",
       "          ...,\n",
       "          [ 0.5741,  0.5741,  0.5741,  ...,  0.5891,  0.6041,  0.6041],\n",
       "          [ 0.5741,  0.5741,  0.5741,  ...,  0.5891,  0.6041,  0.6041],\n",
       "          [ 0.5741,  0.5741,  0.5741,  ...,  0.5891,  0.6041,  0.6041]],\n",
       "\n",
       "         [[ 2.1032,  2.1032,  2.1032,  ..., -1.0252, -1.0252, -1.0252],\n",
       "          [ 2.1032,  2.1032,  2.1032,  ..., -1.0252, -1.0252, -1.0252],\n",
       "          [ 2.1032,  2.1032,  2.1032,  ..., -1.0252, -1.0252, -1.0252],\n",
       "          ...,\n",
       "          [ 1.1789,  1.1789,  1.1789,  ...,  0.8092,  0.8092,  0.8092],\n",
       "          [ 1.1789,  1.1789,  1.1789,  ...,  0.8234,  0.8234,  0.8234],\n",
       "          [ 1.1789,  1.1789,  1.1789,  ...,  0.8377,  0.8377,  0.8377]]]])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "aYiKaCelm3QD",
    "outputId": "3db40d2b-5bfc-4748-d566-bcd8c6426e85"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 768])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img_emb = model.get_image_features(image)\n",
    "img_emb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "epboyMWym3QD",
    "outputId": "cb361cd5-8bfb-464c-d08c-8e0af508a1d9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "           label                                             vector  \\\n",
      "0            bus  [0.022191629, 0.15996626, 0.13757694, 0.058719...   \n",
      "1          train  [-0.025220517, 0.11286947, 0.12789312, 0.01722...   \n",
      "2           road  [-0.052865148, 0.13100702, 0.16935337, -0.0592...   \n",
      "3  aquarium_fish  [-0.007917204, 0.15597954, -0.0052733854, 0.07...   \n",
      "4   pickup_truck  [0.08430549, 0.090862826, 0.07901725, -0.23153...   \n",
      "5       mountain  [-0.026367433, 0.12044583, 0.050345775, 0.0273...   \n",
      "6      telephone  [0.17053358, 0.18755479, 0.10471857, 0.0254959...   \n",
      "7        bicycle  [0.033512242, -0.05071287, 0.088322446, 0.1234...   \n",
      "8          plain  [0.0394026, 0.13996188, 0.11271955, -0.0268946...   \n",
      "9          whale  [0.10277067, 0.050650656, 0.010288065, -0.0908...   \n",
      "\n",
      "    _distance  \n",
      "0  439.172211  \n",
      "1  446.832275  \n",
      "2  449.517883  \n",
      "3  449.519348  \n",
      "4  449.607666  \n",
      "5  449.909271  \n",
      "6  450.536194  \n",
      "7  450.587311  \n",
      "8  450.887573  \n",
      "9  450.944214  \n"
     ]
    }
   ],
   "source": [
    "import lancedb\n",
    "import numpy as np\n",
    "\n",
    "data = []\n",
    "for label_name, embedding in zip(labels, label_emb):\n",
    "    data.append({\"label\": label_name, \"vector\": embedding})\n",
    "\n",
    "db = lancedb.connect(\"./.lancedb\")\n",
    "table = db.create_table(\"my_table\", data, mode=\"Overwrite\")\n",
    "\n",
    "# Prepare the query embedding\n",
    "query_embedding = img_emb.squeeze().detach().cpu().numpy()\n",
    "# Perform the search\n",
    "results = table.search(query_embedding).limit(10).to_pandas()\n",
    "\n",
    "print(results.head(n=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Cc8atI_tm3QD"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
